import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def _plot_metric_row(axes_row, states, metric_col, xlabel=None):
    """Draw one row of scatter panels, one panel per state.

    For each state the array stored in ``<state>.npy`` (relative to the
    current working directory) is loaded.  Column 3 is used as the x value
    and column ``metric_col`` as the y value.  Rows 1 and onward are drawn
    as regular points; row 0 is drawn as a large orange star.

    Extra states beyond the number of axes in ``axes_row`` are ignored, and
    axes without a matching state are left empty.
    """
    for ax, state in zip(axes_row, states):
        results = np.load(state + '.npy')
        x_vals = results[:, 3]
        y_vals = results[:, metric_col]

        # Same drawing order as before: the cloud first, the star on top.
        ood_points = ax.scatter(x_vals[1:], y_vals[1:])
        id_point = ax.scatter(x_vals[0], y_vals[0],
                              s=300, c='orange', marker='*')

        # Bind each label to its artist explicitly.  Passing only a list of
        # labels assigns them in creation order, which put the "ID" label on
        # the point cloud and the "OOD" label on the star.
        # NOTE(review): assumes row 0 is the in-distribution evaluation of
        # `state` and the remaining rows are evaluations elsewhere - confirm.
        ax.legend([id_point, ood_points],
                  ['ID (' + state + ') evaluation', 'OOD evaluation'])

        if xlabel is not None:
            ax.set_xlabel(xlabel)


def plot_dpeo(states):
    """Plot a 4x4 grid of per-state scatter panels and save it as a PNG.

    The first four entries of ``states`` fill rows 0 and 2, the following
    (up to four) entries fill rows 1 and 3.  Rows 0-1 plot column 4 of each
    state's ``.npy`` array against column 3; rows 2-3 plot column 5 against
    column 3.  According to the axis labels set here, column 3 is accuracy,
    rows 0-1 show P(Yhat=1|White) - P(Yhat=1|Black), and rows 2-3 show the
    same difference conditioned on Y=1.

    Parameters
    ----------
    states : sequence of str
        Base names of the ``.npy`` files to load from the current working
        directory.

    Side effects
    ------------
    Writes ``income_dpeo.png`` to the current working directory.
    """
    fig, axs = plt.subplots(4, 4, figsize=(18, 10))

    # Raw strings throughout: '\h' is not a valid escape sequence and newer
    # Python versions warn about it in ordinary string literals.
    dp_label = r'$P(\hat{Y}=1|White)$' + '\n' + r'$-P(\hat{Y}=1|Black)$'
    eo_label = (r'$P(\hat{Y}=1|White,Y=1)$' + '\n'
                + r'$-P(\hat{Y}=1|Black,Y=1)$')
    axs[0, 0].set_ylabel(dp_label)
    axs[1, 0].set_ylabel(dp_label)
    axs[2, 0].set_ylabel(eo_label)
    axs[3, 0].set_ylabel(eo_label)

    first_group = states[:4]
    second_group = states[4:]

    _plot_metric_row(axs[0], first_group, 4)
    _plot_metric_row(axs[1], second_group, 4)
    _plot_metric_row(axs[2], first_group, 5)
    # Only the bottom row carries the x-axis label.
    _plot_metric_row(axs[3], second_group, 5, xlabel='Accuracy')

    fig.savefig('income_dpeo.png', bbox_inches='tight')


def _export_state_summaries(states, file_suffix, metric_cols, test_csv,
                            train_csv):
    """Summarise each state's result array and write two CSV files.

    For every state, ``<state><file_suffix>.npy`` is loaded from the current
    working directory and restricted to the columns selected by
    ``metric_cols``.  The first three values of row 0 go to ``train_csv``;
    the first three column means over rows 1 and onward go to ``test_csv``.
    Each CSV has one line per state: a running index, the state name and the
    three values.

    Parameters
    ----------
    states : sequence of str
        Base names of the result files.
    file_suffix : str
        Text inserted between the state name and the ``.npy`` extension.
    metric_cols : slice
        Column selection applied to the loaded array.
    test_csv, train_csv : str
        Output paths, relative to the current working directory.
    """
    train = pd.DataFrame()
    test = pd.DataFrame()
    for position, state in enumerate(states):
        metrics = np.load(state + file_suffix + '.npy')[:, metric_cols]

        first_row = metrics[0]
        train[position] = [state, first_row[0], first_row[1], first_row[2]]

        rest_mean = metrics[1:, :].mean(axis=0)
        test[position] = [state, rest_mean[0], rest_mean[1], rest_mean[2]]

    # Frames are built one column per state and transposed on output so the
    # CSV layout (one line per state) stays exactly as before.
    test.T.to_csv(test_csv)
    train.T.to_csv(train_csv)


states = ['CA', 'TX', 'DE', 'NV', 'KY', 'FL', 'MO', 'NE']

# All file names below are relative, so the working directory decides which
# result set is read and where the outputs land.
os.chdir('income5')
plot_dpeo(states)
_export_state_summaries(states, '', slice(3, None), 'test5.csv', 'train5.csv')

os.chdir('../income50')
_export_state_summaries(states, '', slice(3, None), 'test50.csv',
                        'train50.csv')

# The files in 'uncon' carry an '_un' suffix and use the first three columns
# of each array instead of columns 3 onward.
os.chdir('../uncon')
_export_state_summaries(states, '_un', slice(None, 3), 'un_test.csv',
                        'un_train.csv')
